{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from com.yahoo.ml.caffe.DisplayUtils import *\n",
    "from com.yahoo.ml.caffe.CaffeOnSpark import *\n",
    "from com.yahoo.ml.caffe.Config import *\n",
    "from com.yahoo.ml.caffe.DataSource import *\n",
    "import caffe\n",
    "from caffe import layers as L, params as P\n",
    "from caffe.proto import caffe_pb2\n",
    "from caffe import TRAIN, TEST\n",
    "net_path = '/Users/mridul/bigml/CaffeOnSpark/data/lenet_dataframe_train_test.prototxt'\n",
    "solver_path = '/Users/mridul/bigml/CaffeOnSpark/data/lenet_dataframe_solver.prototxt'\n",
    "training_source = '/Users/mridul/bigml/mnist_train_dataframe'\n",
    "test_source = '/Users/mridul/bigml/mnist_test_dataframe'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "cos=CaffeOnSpark(sc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def lenet(net_path, training_source, test_source, training_batch_size, test_batch_size):\n",
    "    n = caffe.NetSpec()\n",
    "    n.data, n.label = L.MemoryData(batch_size=training_batch_size, channels=1,height=28,width=28, \n",
    "                          source = training_source,\n",
    "                          share_in_parallel = False,\n",
    "                          source_class=\"com.yahoo.ml.caffe.ImageDataFrame\",\n",
    "                          transform_param=dict(scale=0.00390625),\n",
    "                          include=dict(phase=TRAIN),ntop=2)\n",
    "\n",
    "    train = str(n.to_proto())\n",
    "    n.data, n.label = L.MemoryData(batch_size=test_batch_size, channels=1,height=28,width=28, \n",
    "                          source = test_source,\n",
    "                          share_in_parallel = False,\n",
    "                          source_class=\"com.yahoo.ml.caffe.ImageDataFrame\",\n",
    "                          transform_param=dict(scale=0.00390625),\n",
    "                          include=dict(phase=TEST),ntop=2)\n",
    "\n",
    "    n.conv1 = L.Convolution(n.data, kernel_size=5, num_output=20, weight_filler=dict(type='xavier'),\n",
    "                            bias_filler=dict(type='constant'),\n",
    "                            param=[dict(lr_mult=1),dict(lr_mult=2)])\n",
    "    n.pool1 = L.Pooling(n.conv1, kernel_size=2, stride=2, pool=P.Pooling.MAX)\n",
    "    n.conv2 = L.Convolution(n.pool1, kernel_size=5, num_output=50, weight_filler=dict(type='xavier'),\n",
    "                            bias_filler=dict(type='constant'))\n",
    "    n.pool2 = L.Pooling(n.conv2, kernel_size=2, stride=2, pool=P.Pooling.MAX)\n",
    "    n.ip1 =   L.InnerProduct(n.pool2, num_output=500, weight_filler=dict(type='xavier'),\n",
    "                            bias_filler=dict(type='constant'))\n",
    "    n.relu1 = L.ReLU(n.ip1, in_place=True)\n",
    "    n.ip2 = L.InnerProduct(n.relu1, num_output=10, weight_filler=dict(type='xavier'),\n",
    "                          bias_filler=dict(type='constant'),param=[dict(lr_mult=1),dict(lr_mult=2)])\n",
    "    n.accuracy = L.Accuracy(n.ip2, n.label,include=dict(phase=1))\n",
    "    n.loss =  L.SoftmaxWithLoss(n.ip2, n.label)\n",
    "\n",
    "    network_layers = str(n.to_proto())\n",
    "    \n",
    "    with open(net_path, 'w') as f:\n",
    "        f.write('name:\"LeNet\"\\n')\n",
    "        f.write(train)\n",
    "        f.write(network_layers)\n",
    "        f.close()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Solver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def solver(solver_path,net_path,max_iter, learning_rate):\n",
    "    s = caffe_pb2.SolverParameter()\n",
    "    \n",
    "    s.net = net_path\n",
    "    s.test_interval = 500\n",
    "    s.test_iter.append(1)\n",
    "    s.max_iter = max_iter     # # of times to update the net (training iterations)\n",
    "    # Set the initial learning rate \n",
    "    s.base_lr = learning_rate\n",
    "    # Set `lr_policy` to define how the learning rate changes during training.\n",
    "    s.lr_policy = 'inv'\n",
    "    s.gamma = 0.0001\n",
    "    s.power = 0.75\n",
    "\n",
    "    # Set other SGD hyperparameters. Setting a non-zero `momentum` takes a\n",
    "    # weighted average of the current gradient and previous gradients to make\n",
    "    # learning more stable. L2 weight decay regularizes learning, to help prevent\n",
    "    # the model from overfitting.\n",
    "    s.momentum = 0.9\n",
    "    s.weight_decay = 5e-4\n",
    "\n",
    "    # Display the current training loss \n",
    "    s.display = 100\n",
    "\n",
    "    # Snapshots are files used to store networks we've trained.  Here, we'll\n",
    "    # snapshot every 10K iterations -- ten times during training.\n",
    "    s.snapshot = 10000\n",
    "    s.snapshot_prefix = 'caffesnapshot'\n",
    "    \n",
    "    # Train on the GPU.  Using the CPU to train large networks is very slow.\n",
    "    s.solver_mode = caffe_pb2.SolverParameter.GPU\n",
    "\n",
    "    with open(solver_path, 'w') as f:\n",
    "        f.write(str(s))\n",
    "        f.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training with batch size 64 & iteration 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "lenet(net_path, training_source,test_source,64,64)\n",
    "solver(solver_path, net_path,100,0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "args={}\n",
    "args['conf']='/Users/mridul/bigml/CaffeOnSpark/data/lenet_dataframe_solver.prototxt'\n",
    "args['model']='lenet.model'\n",
    "args['devices']='1'\n",
    "args['clusterSize']='1'\n",
    "cfg=Config(sc,args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dl_train_source = DataSource(sc).getSource(cfg,True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "cos.train(dl_train_source)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "dl_test_source = DataSource(sc).getSource(cfg,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "test_result1=cos.test(dl_test_source)\n",
    "test_result1['batch_size']=64\n",
    "test_result1['learning_rate']=0.01\n",
    "test_result1['iteration']=100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Retrain with batch size 100 & iteration 200 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lenet(net_path, training_source,test_source,100,100)\n",
    "solver(solver_path, net_path,200,0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dl_train_source = DataSource(sc).getSource(cfg,True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "cos.train(dl_train_source)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dl_test_source = DataSource(sc).getSource(cfg,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "test_result2=cos.test(dl_test_source)\n",
    "test_result2['batch_size']=100\n",
    "test_result2['learning_rate']=0.01\n",
    "test_result2['iteration']=200"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare Test1 - Test2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "test_result = [test_result1,test_result2]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Batch Size</th>\n",
       "      <th>Learning Rate</th>\n",
       "      <th>Iteration</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>64</td>\n",
       "      <td>0.01</td>\n",
       "      <td>100</td>\n",
       "      <td>0.920272</td>\n",
       "      <td>0.260897</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>100</td>\n",
       "      <td>0.01</td>\n",
       "      <td>200</td>\n",
       "      <td>0.960800</td>\n",
       "      <td>0.135948</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Batch Size  Learning Rate  Iteration  Accuracy      Loss\n",
       "0          64           0.01        100  0.920272  0.260897\n",
       "1         100           0.01        200  0.960800  0.135948"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = sqlContext.createDataFrame(map(lambda row:(row['batch_size'],\n",
    "                                               row['learning_rate'],\n",
    "                                               row['iteration'],\n",
    "                                               row['accuracy'][0],\n",
    "                                               row['loss'][0]),\n",
    "                                   test_result), [\"Batch Size\", \"Learning Rate\", \"Iteration\",\"Accuracy\", \"Loss\"])\n",
    "t.toPandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multiple Tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "iteration = 200\n",
    "batch_sizes = [16, 32, 64, 128, 256]\n",
    "learning_rates = [0.01, 0.001, 0.0001]\n",
    "test_results=[]\n",
    "for learning_rate in learning_rates:\n",
    "    for batch in batch_sizes:\n",
    "        lenet(net_path, training_source, test_source, batch, batch)\n",
    "        solver(solver_path, net_path,iteration,learning_rate)\n",
    "        dl_train_source = DataSource(sc).getSource(cfg,True)\n",
    "        cos.train(dl_train_source)\n",
    "        dl_test_source = DataSource(sc).getSource(cfg,False)\n",
    "        test_result=cos.test(dl_test_source)\n",
    "        test_result['batch_size']=batch\n",
    "        test_result['learning_rate']=learning_rate\n",
    "        test_result['iteration']=iteration\n",
    "        test_results.append(test_result)\n",
    "\n",
    "    \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare Multiple Tests Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Batch Size</th>\n",
       "      <th>Learning Rate</th>\n",
       "      <th>Iteration</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>16</td>\n",
       "      <td>0.0100</td>\n",
       "      <td>200</td>\n",
       "      <td>0.923438</td>\n",
       "      <td>0.247389</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>32</td>\n",
       "      <td>0.0100</td>\n",
       "      <td>200</td>\n",
       "      <td>0.936298</td>\n",
       "      <td>0.207298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>64</td>\n",
       "      <td>0.0100</td>\n",
       "      <td>200</td>\n",
       "      <td>0.926583</td>\n",
       "      <td>0.239692</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>128</td>\n",
       "      <td>0.0100</td>\n",
       "      <td>200</td>\n",
       "      <td>0.959936</td>\n",
       "      <td>0.138777</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>256</td>\n",
       "      <td>0.0100</td>\n",
       "      <td>200</td>\n",
       "      <td>0.958059</td>\n",
       "      <td>0.145233</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>16</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>200</td>\n",
       "      <td>0.837500</td>\n",
       "      <td>0.567204</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>32</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>200</td>\n",
       "      <td>0.858173</td>\n",
       "      <td>0.501626</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>64</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>200</td>\n",
       "      <td>0.868790</td>\n",
       "      <td>0.462371</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>128</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>200</td>\n",
       "      <td>0.886619</td>\n",
       "      <td>0.431598</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>256</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>200</td>\n",
       "      <td>0.879831</td>\n",
       "      <td>0.461506</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>16</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>200</td>\n",
       "      <td>0.422500</td>\n",
       "      <td>2.096768</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>32</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>200</td>\n",
       "      <td>0.541066</td>\n",
       "      <td>1.966596</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>64</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>200</td>\n",
       "      <td>0.577123</td>\n",
       "      <td>1.945433</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>128</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>200</td>\n",
       "      <td>0.362981</td>\n",
       "      <td>2.149804</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>256</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>200</td>\n",
       "      <td>0.510794</td>\n",
       "      <td>2.039266</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    Batch Size  Learning Rate  Iteration  Accuracy      Loss\n",
       "0           16         0.0100        200  0.923438  0.247389\n",
       "1           32         0.0100        200  0.936298  0.207298\n",
       "2           64         0.0100        200  0.926583  0.239692\n",
       "3          128         0.0100        200  0.959936  0.138777\n",
       "4          256         0.0100        200  0.958059  0.145233\n",
       "5           16         0.0010        200  0.837500  0.567204\n",
       "6           32         0.0010        200  0.858173  0.501626\n",
       "7           64         0.0010        200  0.868790  0.462371\n",
       "8          128         0.0010        200  0.886619  0.431598\n",
       "9          256         0.0010        200  0.879831  0.461506\n",
       "10          16         0.0001        200  0.422500  2.096768\n",
       "11          32         0.0001        200  0.541066  1.966596\n",
       "12          64         0.0001        200  0.577123  1.945433\n",
       "13         128         0.0001        200  0.362981  2.149804\n",
       "14         256         0.0001        200  0.510794  2.039266"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = sqlContext.createDataFrame(map(lambda row:(row['batch_size'],\n",
    "                                               row['learning_rate'],\n",
    "                                               row['iteration'],\n",
    "                                               row['accuracy'][0],\n",
    "                                               row['loss'][0]),\n",
    "                                   test_results), [\"Batch Size\", \"Learning Rate\", \"Iteration\",\"Accuracy\", \"Loss\"])\n",
    "t.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
